{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "execution": {},
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/NeuromatchAcademy/course-content/blob/master/tutorials/W3D2_HiddenDynamics/W3D2_Tutorial3.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a> &nbsp; <a href=\"https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/NeoNeuron/professional-workshop-3/master/tutorials/W8_HiddenDynamics/W8_Tutorial3.ipynb\" target=\"_parent\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" alt=\"Open in Kaggle\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
     "# Tutorial 3: The Kalman Filter\n",
    "\n",
    "__Content creators:__ Itzel Olivos Castillo and Xaq Pitkow\n",
    "\n",
    "__Content modified:__ Kai Chen\n",
    "\n",
    "\n",
     "**Useful references:**\n",
     "- Roweis & Ghahramani (1998): A unifying review of linear Gaussian models\n",
     "- Bishop (2006): Pattern Recognition and Machine Learning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Tutorial Objectives\n",
    "\n",
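     "In previous tutorials we used Hidden Markov Models (HMMs) to infer *discrete* latent states from a sequence of measurements. In this tutorial, we will learn how to infer a *continuous* latent variable using the Kalman filter, which performs exact inference in a linear dynamical system: an HMM whose latent state is continuous and whose dynamics and measurements are linear with additive Gaussian noise.\n",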
    "\n",
    "In this tutorial, you will:\n",
    "* Review linear dynamical systems\n",
    "* Learn about the Kalman filter in one dimension\n",
    "* Manipulate parameters of a process to see how the Kalman filter behaves\n",
     "* Think about some core properties of the Kalman filter\n",
    "\n",
    "You can imagine this inference process happening as Mission Control tries to locate and track Astrocat. But you can also imagine that the brain is using an analogous Hidden Markov Model to track objects in the world, or to estimate the consequences of its own actions. And you could use this technique to estimate brain activity from noisy measurements, for understanding or for building a brain-machine interface."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# Imports\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import transforms\n",
    "from collections import namedtuple\n",
    "from scipy.stats import norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "#@title Figure Settings\n",
    "import ipywidgets as widgets       # interactive display\n",
    "from ipywidgets import interactive, interact, HBox, Layout,VBox\n",
    "from IPython.display import HTML\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "\n",
    "nma_style = {\n",
    "    'figure.figsize' : (8, 6),\n",
    "    'figure.autolayout' : True,\n",
    "    'font.size' : 15,\n",
    "    'xtick.labelsize' : 'small',\n",
    "    'ytick.labelsize' : 'small',\n",
    "    'legend.fontsize' : 'small',\n",
    "    'axes.spines.top' : False,\n",
    "    'axes.spines.right' : False,\n",
    "    'xtick.major.size' : 5,\n",
    "    'ytick.major.size' : 5,\n",
    "}\n",
    "for key, value in nma_style.items():\n",
    "    plt.rcParams[key] = value\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @title Plotting Functions\n",
    "\n",
    "def visualize_Astrocat(s, T):\n",
    "  plt.plot(s, color='limegreen', lw=2)\n",
    "  plt.plot([T], [s[-1]], marker='o', markersize=8, color='limegreen')\n",
    "  plt.xlabel('Time t')\n",
    "  plt.ylabel('s(t)')\n",
    "\n",
    "\n",
    "def plot_measurement(s, m, T):\n",
    "  plt.plot(s, color='limegreen', lw=2, label='true position')\n",
    "  plt.plot([T], [s[-1]], marker='o', markersize=8, color='limegreen')\n",
    "  plt.plot(m, '.', color='crimson', lw=2, label='measurement')\n",
    "  plt.xlabel('Time t')\n",
    "  plt.ylabel('s(t)')\n",
    "  plt.legend()\n",
    "  plt.show()\n",
    "\n",
     "def plot_function(u=1,v=2,w=3,x=4,y=5,z=6):\n",
     "    import pandas as pd  # pandas is only needed by this helper\n",
     "    time=np.arange(0,1,0.01)\n",
     "    df=pd.DataFrame({\"y1\":np.sin(time*u*2*np.pi),\"y2\":np.sin(time*v*2*np.pi),\"y3\":np.sin(time*w*2*np.pi),\n",
     "                    \"y4\":np.sin(time*x*2*np.pi),\"y5\":np.sin(time*y*2*np.pi),\"y6\":np.sin(time*z*2*np.pi)})\n",
     "    df.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @title Helper Functions\n",
    "\n",
    "gaussian = namedtuple('Gaussian', ['mean', 'cov'])\n",
    "\n",
     "def filter(D, process_noise, measurement_noise, posterior, m):\n",
     "    \"\"\"One step of the 1D Kalman filter, computed in information form.\n",
     "\n",
     "    Args:\n",
     "      D (scalar): dynamics multiplier\n",
     "      process_noise (scalar): process noise variance\n",
     "      measurement_noise (scalar): measurement noise variance\n",
     "      posterior (gaussian): yesterday's posterior\n",
     "      m (scalar): today's measurement\n",
     "\n",
     "    Returns:\n",
     "      tuple: today's prior, the likelihood and today's posterior (each a `gaussian`)\n",
     "    \"\"\"\n",
     "    # Predict: propagate yesterday's posterior through the dynamics\n",
     "    todays_prior = gaussian(D * posterior.mean, D**2 * posterior.cov + process_noise)\n",
     "    likelihood = gaussian(m, measurement_noise)\n",
     "\n",
     "    # Update: the information (inverse variance) of prior and likelihood adds\n",
     "    info_prior = 1/todays_prior.cov\n",
     "    info_likelihood = 1/likelihood.cov\n",
     "    info_posterior = info_prior + info_likelihood\n",
     "\n",
     "    # Posterior mean: information-weighted average of prior mean and measurement\n",
     "    prior_weight = info_prior / info_posterior\n",
     "    likelihood_weight = info_likelihood / info_posterior\n",
     "    posterior_mean = prior_weight * todays_prior.mean + likelihood_weight * likelihood.mean\n",
     "\n",
     "    posterior_cov = 1/info_posterior\n",
     "    todays_posterior = gaussian(posterior_mean, posterior_cov)\n",
     "\n",
     "    # Equivalent Kalman-gain form of the same update:\n",
     "    #   K = todays_prior.cov / (todays_prior.cov + measurement_noise)\n",
     "    #   posterior_mean = todays_prior.mean + K * (m - todays_prior.mean)\n",
     "    #   posterior_cov = (1 - K) * todays_prior.cov\n",
     "    return todays_prior, likelihood, todays_posterior\n",
    "\n",
    "\n",
     "def paintMyFilter(D, initial_guess, process_noise, measurement_noise, s, m, s_, cov_):\n",
     "    # Compare the student's estimate s_ with the output of the reference filter()\n",
     "    T = len(m)            # number of time steps\n",
     "    t = np.arange(T)      # timeline\n",
     "\n",
     "    filter_s_ = np.zeros(T)    # estimate (posterior mean)\n",
     "    filter_cov_ = np.zeros(T)    # uncertainty (posterior covariance)\n",
     "\n",
     "    posterior = initial_guess\n",
     "    filter_s_[0] = posterior.mean\n",
     "    filter_cov_[0] = posterior.cov\n",
     "\n",
     "    for i in range(1, T):\n",
     "        prior, likelihood, posterior = filter(D, process_noise, measurement_noise, posterior, m[i])\n",
     "        filter_s_[i] = posterior.mean\n",
     "        filter_cov_[i] = posterior.cov\n",
    "\n",
    "    smin = min(min(m),min(s-2*np.sqrt(cov_[-1])),min(s_-2*np.sqrt(cov_[-1])))\n",
    "    smax = max(max(m),max(s+2*np.sqrt(cov_[-1])),max(s_+2*np.sqrt(cov_[-1])))\n",
    "    pscale = 0.2  # scaling factor for displaying pdfs\n",
    "\n",
    "    fig = plt.figure(figsize=[15, 10])\n",
    "    ax = plt.subplot(2, 1, 1)\n",
    "    ax.set_xlabel('time')\n",
    "    ax.set_ylabel('state')\n",
    "    ax.set_xlim([0, T+(T*pscale)])\n",
    "    ax.set_ylim([smin, smax])\n",
    "\n",
     "    ax.plot(t, s, color='limegreen', lw=2, label=\"Astrocat's trajectory\")\n",
    "    ax.plot([t[-1]], [s[-1]], marker='o', markersize=8, color='limegreen')\n",
    "\n",
    "    ax.plot(t, m, '.', color='crimson', lw=2, label='measurements')\n",
    "    ax.plot([t[-1]], [m[-1]], marker='o', markersize=8, color='crimson')\n",
    "\n",
    "    ax.plot(t, filter_s_, color='black', lw=2, label='correct estimated trajectory')\n",
    "    ax.plot([t[-1]], [filter_s_[-1]], marker='o', markersize=8, color='black')\n",
    "\n",
    "    res = '! :)' if np.mean((s_ - filter_s_)**2) < 0.1 else ' :('\n",
    "    ax.plot(t, s_, '--', color='lightgray', lw=2, label='your estimated trajectory' + res)\n",
    "    ax.plot([t[-1]], [s_[-1]], marker='o', markersize=8, color='lightgray')\n",
    "\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
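  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "The helper `filter` above performs the update in *information form*: the inverse variances of the prior and the likelihood add, and the posterior mean is an information-weighted average. Many textbooks instead write the same update with a Kalman gain $K$. The standalone sketch below assumes the scalar model used in this tutorial; the function names `update_information_form` and `update_gain_form` are hypothetical, introduced only for this check. It confirms numerically that both forms give the same posterior.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def update_information_form(prior_mean, prior_var, m, meas_var):\n",
    "    # Information (inverse variance) of prior and likelihood adds\n",
    "    post_var = 1 / (1 / prior_var + 1 / meas_var)\n",
    "    # Posterior mean: information-weighted average of prior mean and measurement\n",
    "    post_mean = post_var * (prior_mean / prior_var + m / meas_var)\n",
    "    return post_mean, post_var\n",
    "\n",
    "def update_gain_form(prior_mean, prior_var, m, meas_var):\n",
    "    # Kalman gain: the fraction of the innovation (m - prior_mean) we trust\n",
    "    K = prior_var / (prior_var + meas_var)\n",
    "    post_mean = prior_mean + K * (m - prior_mean)\n",
    "    post_var = (1 - K) * prior_var\n",
    "    return post_mean, post_var\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "checks = []\n",
    "for _ in range(5):\n",
    "    prior_mean, m = rng.normal(size=2)\n",
    "    prior_var, meas_var = rng.uniform(0.1, 10, size=2)\n",
    "    a = update_information_form(prior_mean, prior_var, m, meas_var)\n",
    "    b = update_gain_form(prior_mean, prior_var, m, meas_var)\n",
    "    checks.append(np.allclose(a, b))\n",
    "\n",
    "print(all(checks))  # True\n",
    "```\n",
    "\n",
    "The gain form makes the intuition explicit: $K$ is close to 1 when the prior is much more uncertain than the measurement, and close to 0 in the opposite case."
   ]
  },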
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Section 1: Astrocat Dynamics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Section 1.1: Simulating Astrocat's movements"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Coding Exercise 1.1: Simulating Astrocat's movements\n",
    "\n",
    "First, you will simulate how Astrocat moves based on stochastic linear dynamics.\n",
    "\n",
    "The linear dynamical system $s_t = Ds_{t-1} + w_{t-1}$ determines Astrocat's position $s_t$. $D$ is a scalar that models how Astrocat would like to change its position over time, and $w_t \\sim \\mathcal{N}(0, \\sigma_p^2)$ is white Gaussian noise caused by unreliable actuators in Astrocat's propulsion unit. \n",
    "\n",
    "Complete the code below to simulate possible trajectories.\n",
    "\n",
     "Before that, execute the following cell to set the default parameters used throughout this tutorial.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# Fixed params\n",
    "\n",
    "np.random.seed(0)\n",
    "\n",
    "T_max = 200\n",
    "D = 1\n",
    "tau_min = 1\n",
    "tau_max = 50\n",
    "process_noise_min = 0.1\n",
    "process_noise_max = 10\n",
    "measurement_noise_min = 0.1\n",
    "measurement_noise_max = 10\n",
    "\n",
    "unit_process_noise = np.random.randn(T_max)     # compute all N(0, 1) in advance to speed up time slider\n",
    "unit_measurement_noise = np.random.randn(T_max)     # compute all N(0, 1) in advance to speed up time slider"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def simulate(D, s0, sigma_p, T):\n",
    "  \"\"\" Compute the response of the linear dynamical system.\n",
    "\n",
    "  Args:\n",
    "    D (scalar): dynamics multiplier\n",
     "    s0 (scalar): initial position\n",
     "    sigma_p (scalar): amount of noise in the system (standard deviation)\n",
     "    T (int): total duration of the simulation\n",
    "\n",
    "  Returns:\n",
    "    ndarray: `s`: astrocat's trajectory up to time T\n",
    "  \"\"\"\n",
    "\n",
    "  # Initialize variables\n",
    "  s = np.zeros(T+1)\n",
    "  s[0] = s0\n",
    "\n",
    "  # Compute the position at time t given the position at time t-1 for all t\n",
    "  # Consider that np.random.normal(mu, sigma) generates a random sample from\n",
    "  # a gaussian with mean = mu and standard deviation = sigma\n",
    "\n",
    "  for t in range(1, len(s)):\n",
    "\n",
    "    ###################################################################\n",
    "    ## Fill out the following then remove\n",
    "    raise NotImplementedError(\"Student exercise: need to implement simulation\")\n",
    "    ###################################################################\n",
    "\n",
    "    # Update position\n",
    "    s[t] = ...\n",
    "\n",
    "  return s\n",
    "\n",
    "\n",
    "# Set random seed\n",
    "np.random.seed(0)\n",
    "\n",
    "# Set parameters\n",
    "D = 0.9    # parameter in s(t)\n",
    "T = 50      # total time duration\n",
    "s0 = 5.     # initial condition of s at time 0\n",
    "sigma_p = 2 # amount of noise in the actuators of astrocat's propulsion unit\n",
    "\n",
    "# Simulate Astrocat\n",
    "s = simulate(D, s0, sigma_p, T)\n",
    "\n",
    "# Visualize\n",
    "visualize_Astrocat(s, T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "\n",
    "def simulate(D, s0, sigma_p, T):\n",
    "  \"\"\" Compute the response of the linear dynamical system.\n",
    "\n",
    "  Args:\n",
    "    D (scalar): dynamics multiplier\n",
     "    s0 (scalar): initial position\n",
     "    sigma_p (scalar): amount of noise in the system (standard deviation)\n",
     "    T (int): total duration of the simulation\n",
    "\n",
    "  Returns:\n",
    "    ndarray: `s`: astrocat's trajectory up to time T\n",
    "  \"\"\"\n",
    "\n",
    "  # Initialize variables\n",
    "  s = np.zeros(T+1)\n",
    "  s[0] = s0\n",
    "\n",
    "  # Compute the position at time t given the position at time t-1 for all t\n",
    "  # Consider that np.random.normal(mu, sigma) generates a random sample from\n",
    "  # a gaussian with mean = mu and standard deviation = sigma\n",
    "\n",
    "  for t in range(1, len(s)):\n",
    "\n",
    "    # Update position\n",
    "    s[t] = D*s[t-1] + np.random.normal(0, sigma_p)\n",
    "\n",
    "  return s\n",
    "\n",
    "\n",
    "# Set random seed\n",
    "np.random.seed(0)\n",
    "\n",
    "# Set parameters\n",
    "D = 0.9    # parameter in s(t)\n",
    "T = 50      # total time duration\n",
    "s0 = 5.     # initial condition of s at time 0\n",
    "sigma_p = 2 # amount of noise in the actuators of astrocat's propulsion unit\n",
    "\n",
    "# Simulate Astrocat\n",
    "s = simulate(D, s0, sigma_p, T)\n",
    "\n",
    "# Visualize\n",
    "visualize_Astrocat(s, T)"
   ]
  },
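  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "For $|D| < 1$ the noise keeps Astrocat moving, but the spread of its position settles to a fixed *stationary* variance $\\sigma_p^2/(1-D^2)$ (the same quantity the Kalman filter demo later uses as its initial prior). As an optional check, the sketch below simulates many independent trajectories at once with vectorized NumPy and compares the empirical variance at the final step with that formula. The function `simulate_many` and the number of trajectories are illustrative assumptions, not part of the exercise.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def simulate_many(D, s0, sigma_p, T, n_traj, rng):\n",
    "    # Each row is one trajectory of s_t = D*s_{t-1} + w_{t-1}\n",
    "    s = np.zeros((n_traj, T + 1))\n",
    "    s[:, 0] = s0\n",
    "    for t in range(1, T + 1):\n",
    "        s[:, t] = D * s[:, t - 1] + rng.normal(0, sigma_p, size=n_traj)\n",
    "    return s\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "s_many = simulate_many(D=0.9, s0=5.0, sigma_p=2.0, T=200, n_traj=20000, rng=rng)\n",
    "\n",
    "empirical_var = s_many[:, -1].var()\n",
    "stationary_var = 2.0**2 / (1 - 0.9**2)\n",
    "print(empirical_var, stationary_var)  # close, up to sampling error\n",
    "```\n",
    "\n",
    "With $D = 0.9$ and $\\sigma_p = 2$ the formula gives about 21. For $|D| \\ge 1$ there is no stationary variance, and the empirical spread keeps growing with $T$."
   ]
  },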
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Interactive Demo 1.1: Playing with Astrocat movement\n",
    "\n",
    "We will use the function you just implemented in a demo, where you can change the value of $D$ and see what happens.\n",
    "\n",
    "\n",
    "1.   What happens when D is large (>1)? Why?\n",
    "2.   What happens when D is a large negative number (<-1)? Why?\n",
    "3.   What about when D is zero?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @markdown Execute this cell to enable the demo\n",
    "\n",
    "\n",
    "@widgets.interact(D=widgets.FloatSlider(value=-.5, min=-2, max=2, step=0.1))\n",
    "def plot(D=D):\n",
    "\n",
    "    # Set parameters\n",
    "    T = 50      # total time duration\n",
    "    s0 = 5.     # initial condition of s at time 0\n",
    "    sigma_p = 2 # amount of noise in the actuators of astrocat's propulsion unit\n",
    "\n",
    "    # Simulate Astrocat\n",
    "    s = simulate(D, s0, sigma_p, T)\n",
    "\n",
    "    # Visualize\n",
    "    visualize_Astrocat(s, T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove explanation\n",
    "\n",
    "\"\"\"\n",
    "\n",
     "1) When D is large (D > 1), the state at time step t is amplified relative to the state at time\n",
     "   step t-1. Ignoring the noise term, D = 2 means that the state at each time step is double\n",
     "   the one before! So the magnitude of the state grows exponentially and explodes towards\n",
     "   infinity.\n",
     "\n",
     "2) If D is a large negative number (D < -1), the state at time step t has the opposite sign of\n",
     "   the state at time step t-1, and its magnitude also grows. So the state oscillates around zero\n",
     "   with ever-increasing amplitude.\n",
     "\n",
     "3) When D is zero, the state at time t does not depend on the previous state; it is simply\n",
     "   drawn from the noise distribution.\n",
    "\n",
    "\"\"\""
   ]
  },
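  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "Unrolling the recursion $s_t = Ds_{t-1} + w_{t-1}$ makes these answers precise. Substituting the equation into itself repeatedly gives\n",
    "\n",
    "$$\n",
    "s_t = D^t s_0 + \\sum_{k=0}^{t-1} D^{t-1-k} w_k ,\n",
    "$$\n",
    "\n",
    "and, since the $w_k$ are independent with zero mean and variance $\\sigma_p^2$,\n",
    "\n",
    "$$\n",
    "\\mathbb{E}[s_t] = D^t s_0, \\qquad \\mathrm{Var}[s_t] = \\sigma_p^2 \\sum_{j=0}^{t-1} D^{2j}.\n",
    "$$\n",
    "\n",
    "For $|D| > 1$ the variance (and the mean, if $s_0 \\neq 0$) grows without bound; for $D < 0$ the factor $D^t$ flips sign at every step, which produces the oscillation; for $D = 0$ only the last noise term survives, so $s_t = w_{t-1}$; and for $|D| < 1$ the variance converges to the geometric-series limit $\\sigma_p^2/(1-D^2)$."
   ]
  },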
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Section 1.2: Measuring Astrocat's movements\n",
    "\n",
    "*Estimated timing to here from start of tutorial: 10 min*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Coding Exercise 1.2.1: Reading measurements from Astrocat's collar\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "We will estimate Astrocat's actual position using measurements of a noisy sensor attached to its collar. \n",
    "\n",
    "Complete the function below to read measurements from Astrocat's collar. These measurements are correct except for additive Gaussian noise whose standard deviation is given by the input argument `sigma_measurements`."
   ]
  },
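  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "In equation form, each collar reading is the true position plus independent Gaussian noise. Writing $v_t$ for that noise and $\\sigma_m$ for its standard deviation (notation introduced here; it corresponds to `sigma_measurements` in the code),\n",
    "\n",
    "$$\n",
    "m_t = s_t + v_t, \\qquad v_t \\sim \\mathcal{N}(0, \\sigma_m^2) \\quad\\Longrightarrow\\quad p(m_t \\mid s_t) = \\mathcal{N}(m_t;\\, s_t, \\sigma_m^2).\n",
    "$$\n",
    "\n",
    "Read as a function of the unknown position $s_t$ for a fixed reading $m_t$, this is the *likelihood*: a Gaussian centered on the measurement, with variance $\\sigma_m^2$. The Kalman filter combines it with a prior built from the dynamics."
   ]
  },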
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def read_collar(s, sigma_measurements):\n",
    "  \"\"\" Compute the measurements of the noisy sensor attached to astrocat's collar\n",
    "\n",
    "  Args:\n",
    "    s (ndarray): astrocat's true position over time\n",
    "    sigma_measurements (scalar): amount of noise in the sensor (standard deviation)\n",
    "\n",
    "  Returns:\n",
    "    ndarray: `m`: astrocat's position over time according to the sensor\n",
    "  \"\"\"\n",
    "\n",
    "  # Initialize variables\n",
    "  m = np.zeros(len(s))\n",
    "\n",
     "  # For all time t, add white Gaussian noise with standard deviation sigma_measurements\n",
    "  # Consider that np.random.normal(mu, sigma) generates a random sample from\n",
    "  # a gaussian with mean = mu and standard deviation = sigma\n",
    "\n",
    "  for t in range(len(s)):\n",
    "\n",
    "    ###################################################################\n",
    "    ## Fill out the following then remove\n",
    "    raise NotImplementedError(\"Student exercise: need to implement read_collar function\")\n",
    "    ###################################################################\n",
    "\n",
    "    # Read measurement\n",
    "    m[t] = ...\n",
    "\n",
    "  return m\n",
    "\n",
    "\n",
    "# Set parameters\n",
    "np.random.seed(0)\n",
    "D = 0.9    # parameter in s(t)\n",
    "T = 50      # total time duration\n",
    "s0 = 5.     # initial condition of s at time 0\n",
    "sigma_p = 2 # amount of noise in the actuators of astrocat's propulsion unit\n",
    "sigma_measurements = 4 # amount of noise in astrocat's collar\n",
    "\n",
    "# Simulate Astrocat\n",
    "s = simulate(D, s0, sigma_p, T)\n",
    "\n",
    "# Take measurement from collar\n",
    "m = read_collar(s, sigma_measurements)\n",
    "\n",
    "# Visualize\n",
    "plot_measurement(s, m, T)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "def read_collar(s, sigma_measurements):\n",
    "  \"\"\" Compute the measurements of the noisy sensor attached to astrocat's collar\n",
    "\n",
    "  Args:\n",
    "    s (ndarray): astrocat's true position over time\n",
    "    sigma_measurements (scalar): amount of noise in the sensor (standard deviation)\n",
    "\n",
    "  Returns:\n",
    "    ndarray: `m`: astrocat's position over time according to the sensor\n",
    "  \"\"\"\n",
    "\n",
    "  # Initialize variables\n",
    "  m = np.zeros(len(s))\n",
    "\n",
     "  # For all time t, add white Gaussian noise with standard deviation sigma_measurements\n",
    "  # Consider that np.random.normal(mu, sigma) generates a random sample from\n",
    "  # a gaussian with mean = mu and standard deviation = sigma\n",
    "\n",
    "  for t in range(len(s)):\n",
    "\n",
    "    # Read measurement\n",
    "    m[t] = s[t] + np.random.normal(0, sigma_measurements)\n",
    "\n",
    "  return m\n",
    "\n",
    "\n",
    "# Set parameters\n",
    "np.random.seed(0)\n",
    "D = 0.9    # parameter in s(t)\n",
    "T = 50      # total time duration\n",
    "s0 = 5.     # initial condition of s at time 0\n",
    "sigma_p = 2 # amount of noise in the actuators of astrocat's propulsion unit\n",
    "sigma_measurements = 4 # amount of noise in astrocat's collar\n",
    "\n",
    "# Simulate Astrocat\n",
    "s = simulate(D, s0, sigma_p, T)\n",
    "\n",
    "# Take measurement from collar\n",
    "m = read_collar(s, sigma_measurements)\n",
    "\n",
    "# Visualize\n",
    "plot_measurement(s, m, T)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Coding Exercise 1.2.2: Compare true states to measured states\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
     "Make a scatter plot of the true states against the measurements to see how noisy Astrocat's collar is. This exercise shows why tracking Astrocat using the raw measurements alone gives a poor estimate of its position.\n",
     "\n",
     "A Kalman filter can do much better by combining the measurements with a model of Astrocat's dynamics!"
   ]
  },
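  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "To put a number on how noisy the collar is, the sketch below computes the root-mean-square error (RMSE) of the raw readings over a long, freshly simulated trajectory. It assumes the same dynamics and measurement model as above; the function name `raw_measurement_rmse`, the seed and the trajectory length are illustrative choices. Because the noise is additive and independent at each step, the RMSE comes out close to `sigma_measurements`, and it does not shrink as more readings arrive.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def raw_measurement_rmse(D, sigma_p, sigma_m, T, seed=1):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    # True state: s_t = D*s_{t-1} + w_{t-1}\n",
    "    s = np.zeros(T + 1)\n",
    "    for t in range(1, T + 1):\n",
    "        s[t] = D * s[t - 1] + rng.normal(0, sigma_p)\n",
    "    # Collar readings: true state plus independent Gaussian noise\n",
    "    m = s + rng.normal(0, sigma_m, size=T + 1)\n",
    "    # Error of using each reading directly as the position estimate\n",
    "    return np.sqrt(np.mean((m - s) ** 2))\n",
    "\n",
    "rmse = raw_measurement_rmse(D=0.9, sigma_p=2.0, sigma_m=4.0, T=10000)\n",
    "print(rmse)  # close to sigma_m = 4\n",
    "```"
   ]
  },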
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def compare(s, m):\n",
    "  \"\"\" Compute a scatter plot\n",
    "\n",
    "  Args:\n",
    "    s (ndarray): astrocat's true position over time\n",
    "    m (ndarray): astrocat's measured position over time according to the sensor\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "\n",
    "  ###################################################################\n",
    "  ## Fill out the following then remove\n",
    "  raise NotImplementedError(\"Student exercise: need to implement compare function\")\n",
    "  ###################################################################\n",
    "\n",
    "  fig = plt.figure()\n",
    "  ax = fig.add_subplot(111)\n",
    "  sbounds = 1.1*max(max(np.abs(s)), max(np.abs(m)))\n",
    "  ax.plot([-sbounds, sbounds], [-sbounds, sbounds], 'k')    # plot line of equality\n",
    "  ax.set_xlabel('state')\n",
    "  ax.set_ylabel('measurement')\n",
    "  ax.set_aspect('equal')\n",
    "\n",
    "  # Complete a scatter plot: true state versus measurements\n",
    "  ...\n",
    "\n",
    "\n",
    "# Set parameters\n",
    "np.random.seed(0)\n",
    "D = 0.9    # parameter in s(t)\n",
    "T = 50      # total time duration\n",
    "s0 = 5.     # initial condition of s at time 0\n",
    "sigma_p = 2 # amount of noise in the actuators of astrocat's propulsion unit\n",
    "sigma_measurements = 4 # amount of noise in astrocat's collar\n",
    "\n",
    "# Simulate Astrocat\n",
    "s = simulate(D, s0, sigma_p, T)\n",
    "\n",
    "# Take measurement from collar\n",
    "m = read_collar(s, sigma_measurements)\n",
    "\n",
    "# Visualize true vs measured states\n",
    "compare(s,m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "\n",
    "def compare(s, m):\n",
    "  \"\"\" Compute a scatter plot\n",
    "\n",
    "  Args:\n",
    "    s (ndarray): astrocat's true position over time\n",
    "    m (ndarray): astrocat's measured position over time according to the sensor\n",
    "\n",
    "  \"\"\"\n",
    "\n",
    "  fig = plt.figure()\n",
    "  ax = fig.add_subplot(111)\n",
    "  sbounds = 1.1*max(max(np.abs(s)), max(np.abs(m)))\n",
    "  ax.plot([-sbounds, sbounds], [-sbounds, sbounds], 'k')    # plot line of equality\n",
    "  ax.set_xlabel('state')\n",
    "  ax.set_ylabel('measurement')\n",
    "  ax.set_aspect('equal')\n",
    "\n",
    "  # Complete a scatter plot: true state versus measurements\n",
    "  plt.scatter(s, m, marker='.', color='red', s=100)\n",
    "\n",
    "\n",
    "# Set parameters\n",
    "np.random.seed(0)\n",
    "D = 0.9    # parameter in s(t)\n",
    "T = 50      # total time duration\n",
    "s0 = 5.     # initial condition of s at time 0\n",
    "sigma_p = 2 # amount of noise in the actuators of astrocat's propulsion unit\n",
    "sigma_measurements = 4 # amount of noise in astrocat's collar\n",
    "\n",
    "# Simulate Astrocat\n",
    "s = simulate(D, s0, sigma_p, T)\n",
    "\n",
    "# Take measurement from collar\n",
    "m = read_collar(s, sigma_measurements)\n",
    "\n",
    "# Visualize true vs measured states\n",
    "compare(s,m)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Section 2: The Kalman filter\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Section 2.1: Using the Kalman filter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Interactive Demo 2.1: The Kalman filter in action\n",
    "\n",
    "Next we provide you with an interactive visualization to understand how the Kalman filter works. Play with the sliders to gain an intuition for how the different factors affect the Kalman filter's inferences. You will code the Kalman filter yourself in the next exercise.\n",
    "\n",
    "The sliders:\n",
     "* current time: the Kalman filter combines the measurements up to this time.\n",
     "* dynamics time constant $\\tau$: this determines the dynamics multiplier, $D=e^{-\\Delta t/\\tau}$, where $\\Delta t$ is the discrete time step (here 1).\n",
     "* process noise: standard deviation of the noise in the actuators of Astrocat's propulsion unit\n",
     "* observation noise: standard deviation of the noise in our measurements (when we read the collar)\n",
    "\n",
    "Some questions to consider:\n",
    "- What affects the predictability of Astrocat?\n",
    "- How does confidence change over time?\n",
    "- What affects the relative weight of the new measurement?\n",
    "- How is the error related to the posterior variance?"
   ]
  },
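  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "The demo cell below converts its slider values into model parameters: it sets $D = e^{-\\Delta t/\\tau}$ with $\\Delta t = 1$, squares the two noise sliders to obtain variances, and starts the filter from the stationary distribution $\\mathcal{N}(0, \\sigma_p^2/(1-D^2))$. The sketch below reproduces only that bookkeeping, with a hypothetical helper name `demo_parameters`, so you can see how the sliders map onto the model: a larger $\\tau$ pushes $D$ toward 1, making Astrocat's position more persistent and its stationary variance larger.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def demo_parameters(tau, process_noise, measurement_noise, dt=1.0):\n",
    "    # Dynamics multiplier from the time constant\n",
    "    D = np.exp(-dt / tau)\n",
    "    # Slider values are standard deviations; the filter works with variances\n",
    "    process_noise_cov = process_noise ** 2\n",
    "    measurement_noise_cov = measurement_noise ** 2\n",
    "    # Stationary (equilibrium) variance of s_t, used as the initial prior\n",
    "    prior_cov = process_noise_cov / (1 - D ** 2)\n",
    "    return D, process_noise_cov, measurement_noise_cov, prior_cov\n",
    "\n",
    "for tau in [1, 10, 50]:\n",
    "    D, q, r, p0 = demo_parameters(tau, process_noise=2.0, measurement_noise=3.0)\n",
    "    print(tau, round(D, 3), round(D ** tau, 3), round(p0, 1))\n",
    "```\n",
    "\n",
    "Whatever the value of $\\tau$, $D^\\tau = e^{-1} \\approx 0.37$: after $\\tau$ steps the deterministic part of the motion has decayed by a factor of $e$, which is what makes $\\tau$ a time constant."
   ]
  },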
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @markdown Execute this cell to enable the demo. It takes a few seconds to update so please be patient.\n",
    "display(HTML('''<style>.widget-label { min-width: 15ex !important; }</style>'''))\n",
    "\n",
     "@widgets.interact(T=widgets.IntSlider(T_max//4, description=\"current time\", min=1, max=T_max-1),\n",
    "                  tau=widgets.FloatSlider(tau_max/2, description='dynamics time constant', min=tau_min, max=tau_max),\n",
    "                  process_noise=widgets.FloatSlider(2, description=\"process noise\", min=process_noise_min, max=process_noise_max),\n",
    "                  measurement_noise=widgets.FloatSlider(3, description=\"observation noise\", min=measurement_noise_min, max=measurement_noise_max),\n",
    "                  flag_s = widgets.Checkbox(value=True, description='state', disabled=True, indent=False),\n",
    "                  flag_m = widgets.Checkbox(value=False, description='measurement', disabled=False, indent=False),\n",
    "                  flag_s_ = widgets.Checkbox(value=False, description='estimate', disabled=False, indent=False),\n",
    "                  flag_err_ = widgets.Checkbox(value=False, description='estimator confidence intervals', disabled=False, indent=False))\n",
    "\n",
    "def stochastic_system(T, tau, process_noise, measurement_noise, flag_m, flag_s_, flag_err_):\n",
    "    t = np.arange(0, T_max, 1)              # timeline\n",
    "    s = np.zeros(T_max)                     # states\n",
    "    D = np.exp(-1/tau)                      # dynamics multiplier (matrix if s is vector)\n",
    "    process_noise_cov = process_noise**2\n",
    "    measurement_noise_cov = measurement_noise**2\n",
    "\n",
    "    prior_mean = 0\n",
    "    prior_cov = process_noise_cov/(1-D**2)\n",
    "\n",
    "    s[0] = np.sqrt(prior_cov) * unit_process_noise[0]   # Sample initial condition from equilibrium distribution\n",
    "\n",
    "    m = np.zeros(T_max)    # measurement\n",
    "    s_ = np.zeros(T_max)    # estimate (posterior mean)\n",
    "    cov_ = np.zeros(T_max)    # uncertainty (posterior covariance)\n",
    "\n",
    "    s_[0] = prior_mean\n",
    "    cov_[0] = prior_cov\n",
    "    posterior = gaussian(prior_mean, prior_cov)\n",
    "\n",
    "    captured_prior = None\n",
    "    captured_likelihood = None\n",
    "    captured_posterior = None\n",
    "\n",
    "    onfilter = True\n",
    "    for i in range(1, T_max):\n",
    "        s[i] = D * s[i-1] + process_noise * unit_process_noise[i-1]\n",
    "\n",
    "        if onfilter:\n",
    "          m[i] = s[i] + measurement_noise * unit_measurement_noise[i]\n",
    "\n",
    "          prior, likelihood, posterior = filter(D, process_noise_cov, measurement_noise_cov, posterior, m[i])\n",
    "\n",
    "          s_[i] =  posterior.mean\n",
    "          cov_[i] = posterior.cov\n",
    "\n",
    "        if i == T:\n",
    "          onfilter = False\n",
    "          captured_prior = prior\n",
    "          captured_likelihood = likelihood\n",
    "          captured_posterior = posterior\n",
    "\n",
    "    smin = min(min(m),min(s-2*np.sqrt(cov_[-1])),min(s_-2*np.sqrt(cov_[-1])))\n",
    "    smax = max(max(m),max(s+2*np.sqrt(cov_[-1])),max(s_+2*np.sqrt(cov_[-1])))\n",
    "    pscale = 0.2  # scaling factor for displaying pdfs\n",
    "\n",
    "    fig = plt.figure(figsize=[15, 10])\n",
    "    ax = plt.subplot(2, 1, 1)\n",
    "    ax.set_xlabel('time')\n",
    "    ax.set_ylabel('state')\n",
    "    ax.set_xlim([0, T_max+(T_max*pscale)])\n",
    "    ax.set_ylim([smin, smax])\n",
    "\n",
    "    show_pdf = [False, False]\n",
    "    ax.plot(t[:T+1], s[:T+1], color='limegreen', lw=2)\n",
    "    ax.plot(t[T:], s[T:], color='limegreen', lw=2, alpha=0.3)\n",
    "    ax.plot([t[T:T+1]], [s[T:T+1]], marker='o', markersize=8, color='limegreen')\n",
    "\n",
    "    if flag_m:\n",
    "        ax.plot(t[:T+1], m[:T+1], '.', color='crimson', lw=2)\n",
    "        ax.plot([t[T:T+1]], [m[T:T+1]], marker='o', markersize=8, color='crimson')\n",
    "\n",
    "        domain = np.linspace(ax.get_ylim()[0], ax.get_ylim()[1], 500)\n",
    "        pdf_likelihood = norm.pdf(domain, captured_likelihood.mean, np.sqrt(captured_likelihood.cov))\n",
    "        ax.fill_betweenx(domain, T + pdf_likelihood*(T_max*pscale), T, color='crimson', alpha=0.5, label='likelihood', edgecolor=\"crimson\", linewidth=0)\n",
    "        ax.plot(T + pdf_likelihood*(T_max*pscale), domain, color='crimson', linewidth=2.0)\n",
    "\n",
    "        ax.legend(ncol=3, loc='upper left')\n",
    "        show_pdf[0] = True\n",
    "\n",
    "    if flag_s_:\n",
    "        ax.plot(t[:T+1], s_[:T+1], color='black', lw=2)\n",
    "        ax.plot([t[T:T+1]], [s_[T:T+1]], marker='o', markersize=8, color='black')\n",
    "        show_pdf[1] = True\n",
    "\n",
    "    if flag_err_:\n",
    "        ax.fill_between(t[:T+1], s_[:T+1] + 2 * np.sqrt(cov_)[:T+1], s_[:T+1] - 2 * np.sqrt(cov_)[:T+1], color='black', alpha=0.3)\n",
    "        show_pdf[1] = True\n",
    "\n",
    "    if show_pdf[1]:\n",
    "        domain = np.linspace(ax.get_ylim()[0], ax.get_ylim()[1], 500)\n",
    "        pdf_post = norm.pdf(domain, captured_posterior.mean, np.sqrt(captured_posterior.cov))\n",
    "        ax.fill_betweenx(domain, T + pdf_post*(T_max*pscale), T, color='black', alpha=0.5, label='posterior', edgecolor=\"black\", linewidth=0)\n",
    "        ax.plot(T + pdf_post*(T_max*pscale), domain, color='black', linewidth=2.0)\n",
    "        ax.legend(ncol=3, loc='upper left')\n",
    "\n",
    "    if show_pdf[0] and show_pdf[1]:\n",
    "        domain = np.linspace(ax.get_ylim()[0], ax.get_ylim()[1], 500)\n",
    "        pdf_prior = norm.pdf(domain, captured_prior.mean, np.sqrt(captured_prior.cov))\n",
    "        ax.fill_betweenx(domain, T + pdf_prior*(T_max*pscale), T, color='dodgerblue', alpha=0.5, label='prior', edgecolor=\"dodgerblue\", linewidth=0)\n",
    "        ax.plot(T + pdf_prior*(T_max*pscale), domain, color='dodgerblue', linewidth=2.0)\n",
    "        ax.legend(ncol=3, loc='upper left')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Coding Exercise 2.1: Implement your own Kalman filter\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "As you saw in the video and the previous exercise, a Kalman filter estimates a posterior probability distribution *recursively* over time, using a mathematical model of the process together with the incoming measurements. This dynamic posterior lets us improve our guess about Astrocat's position as new measurements arrive. Moreover, its mean is the best estimate of Astrocat's actual position at each time step, in the sense that it minimizes the expected squared error given all measurements so far.\n",
    "\n",
    "Now it's your turn! Follow this recipe to complete the code below and implement your own Kalman filter:\n",
    "\n",
    "**Step 1: Change yesterday's posterior into today's prior** \n",
    "\n",
    "Use the mathematical model to calculate how deterministic changes in the process shift yesterday's posterior, $\\mathcal{N}(\\mu_{s_{t-1}}, \\sigma_{s_{t-1}}^2)$, and how random changes in the process broaden the shifted distribution:\n",
    "\n",
    "> $p(s_t|m_{1:t-1}) = p(Ds_{t-1}+w_{t-1} | m_{1:t-1}) = \\mathcal{N}(D\\mu_{s_{t-1}} + 0, D^2\\sigma_{s_{t-1}}^2 +\\sigma_p^2)$\n",
    "\n",
    "Note that we use $\\sigma_p$ here to denote the process noise, while the video used $\\sigma_w$ (we changed the notation to stay consistent with the earlier sections).\n",
    "\n",
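    "As a quick sanity check of Step 1, here is a minimal sketch that compares the predicted mean and variance with a Monte Carlo simulation of $Ds_{t-1} + w_{t-1}$, where $s_{t-1}$ is drawn from yesterday's posterior. The helper `predict` and the numbers are hypothetical and are not part of the exercise:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def predict(mean, var, D, process_var):\n",
    "    # Step 1 (hypothetical helper): scale the mean by D, then scale the\n",
    "    # variance by D**2 and add the process-noise variance\n",
    "    return D * mean, D**2 * var + process_var\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "D = np.exp(-1 / 25)               # dynamics multiplier with tau = 25, as in the exercise\n",
    "post_mean, post_var = 10.0, 4.0   # hypothetical posterior at time t-1\n",
    "process_var = 2.0**2              # process noise standard deviation of 2\n",
    "\n",
    "prior_mean, prior_var = predict(post_mean, post_var, D, process_var)\n",
    "\n",
    "# Monte Carlo: sample s_{t-1} from yesterday's posterior and push it through the dynamics\n",
    "s_prev = rng.normal(post_mean, np.sqrt(post_var), size=200_000)\n",
    "s_next = D * s_prev + rng.normal(0, np.sqrt(process_var), size=s_prev.size)\n",
    "\n",
    "print(prior_mean, prior_var)          # analytic prior, approximately 9.61 and 7.69\n",
    "print(s_next.mean(), s_next.var())    # empirical prior, which should be close\n",
    "```\n",
    "\n",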
    "**Step 2: Multiply today's prior by likelihood** \n",
    "\n",
    "Use the latest measurement of Astrocat's collar (fresh evidence) to form a new estimate somewhere between this measurement and what we predicted in Step 1. The new posterior is proportional to the product of the Gaussian computed in Step 1 (a.k.a. today's prior) and the likelihood, which is also modeled as a Gaussian $\\mathcal{N}(m_t, \\sigma_m^2)$:\n",
    "\n",
    "**2a: add information from prior and likelihood** \n",
    "\n",
    "To find the posterior variance, we first compute the posterior information (which is the inverse of the variance) by adding the information provided by the prior and the likelihood:\n",
    "\n",
    "> $\\frac{1}{\\sigma_{s_t}^2} = \\frac{1}{D^2\\sigma_{s_{t-1}}^2 +\\sigma_p^2} + \\frac{1}{\\sigma_m^2} $\n",
    "\n",
    "Now we can take the inverse of the posterior information to get back the posterior variance.\n",
    "\n",
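    "For example, take some hypothetical numbers: today's prior has variance $7.69$, and the measurement noise is the exercise's $\\sigma_m = 9$ (so $\\sigma_m^2 = 81$). Because information only adds, the posterior variance comes out smaller than both the prior variance and the likelihood variance:\n",
    "\n",
    "```python\n",
    "prior_var = 7.69          # hypothetical variance of today's prior\n",
    "likelihood_var = 9.0**2   # measurement noise variance (sigma_m = 9, as in the exercise)\n",
    "\n",
    "info_prior = 1 / prior_var\n",
    "info_likelihood = 1 / likelihood_var\n",
    "info_posterior = info_prior + info_likelihood   # information adds\n",
    "posterior_var = 1 / info_posterior              # convert back to a variance\n",
    "\n",
    "print(round(posterior_var, 3))   # 7.023\n",
    "print(posterior_var < prior_var and posterior_var < likelihood_var)   # True\n",
    "```\n",
    "\n",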
    "**2b: add means from prior and likelihood** \n",
    "\n",
    "To find the posterior mean, we calculate a weighted average of means from prior and likelihood, where each weight, $g$, is just the fraction of information that each Gaussian provides!\n",
    "\n",
    "> $g_{\\rm{prior}} = \\frac{\\rm{information}_{\\textit{ }\\rm{prior}}}{\\rm{information}_{\\textit{ }\\rm{posterior}}}$\n",
    ">\n",
    "> $g_{\\rm{likelihood}} = \\frac{\\rm{information}_{\\textit{ }\\rm{likelihood}}}{\\rm{information}_{\\textit{ }\\rm{posterior}}}$\n",
    ">\n",
    "> $\\mu_{s_t} = g_{\\rm{prior}} D\\mu_{s_{t-1}} + g_{\\rm{likelihood}} m_t$\n",
    "    \n",
    "Congrats!\n",
    "\n",
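    "To convince yourself that this recipe really is the product rule for Gaussians, here is a minimal sketch that multiplies the prior and likelihood densities on a fine grid, renormalizes, and compares the resulting mean and variance with Steps 2a and 2b. The numbers are hypothetical, and `gauss_pdf` is a helper written only for this check:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def gauss_pdf(x, mean, var):\n",
    "    # density of N(mean, var) evaluated at x\n",
    "    return np.exp(-(x - mean)**2 / (2 * var)) / np.sqrt(2 * np.pi * var)\n",
    "\n",
    "prior_mean, prior_var = 9.6, 7.69   # hypothetical prior (output of Step 1)\n",
    "m_t, meas_var = 15.0, 81.0          # hypothetical measurement and its noise variance\n",
    "\n",
    "# Recipe: add information (2a), then weight the means by their share of it (2b)\n",
    "info_post = 1 / prior_var + 1 / meas_var\n",
    "g_prior = (1 / prior_var) / info_post\n",
    "g_like = (1 / meas_var) / info_post\n",
    "post_mean = g_prior * prior_mean + g_like * m_t\n",
    "post_var = 1 / info_post\n",
    "\n",
    "# Brute force: multiply the two densities on a fine grid and renormalize\n",
    "x = np.linspace(-60, 80, 200_001)\n",
    "dx = x[1] - x[0]\n",
    "product = gauss_pdf(x, prior_mean, prior_var) * gauss_pdf(x, m_t, meas_var)\n",
    "product /= product.sum() * dx\n",
    "grid_mean = (x * product).sum() * dx\n",
    "grid_var = ((x - grid_mean)**2 * product).sum() * dx\n",
    "\n",
    "print(post_mean, grid_mean)   # should agree to several decimal places\n",
    "print(post_var, grid_var)\n",
    "```\n",
    "\n",
    "The two results agree for any values you choose. This is why the filter never needs a grid: two numbers, a mean and a variance, carry the whole Gaussian.\n",
    "\n",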
    "**Implementation detail:** You can access the statistics of a Gaussian by typing, e.g.,\n",
    "\n",
    "```\n",
    "prior.mean\n",
    "prior.cov\n",
    "```\n",
    "\n",
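    "To experiment with this interface outside the notebook, you could use a named tuple as a stand-in with the same `.mean` and `.cov` fields. This is an assumption about how the helper is defined, not a copy of the notebook's own code:\n",
    "\n",
    "```python\n",
    "from collections import namedtuple\n",
    "\n",
    "# Hypothetical stand-in for the notebook's `gaussian` helper: a record with .mean and .cov\n",
    "gaussian = namedtuple('Gaussian', ['mean', 'cov'])\n",
    "\n",
    "prior = gaussian(0.0, 4.0)\n",
    "print(prior.mean, prior.cov)   # 0.0 4.0\n",
    "```\n",
    "\n",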
    "**Optional: Relationship to classic description of Kalman filter:**\n",
    "\n",
    "We're teaching this recipe because it is interpretable and connects to past lessons about the sum rule and product rule for Gaussians. But the classic description of the Kalman filter is a little different. Because the weights $g_{\\rm{prior}}$ and $g_{\\rm{likelihood}}$ add up to $1$, one can be written in terms of the other: $g_{\\rm{prior}} = 1 - g_{\\rm{likelihood}}$. If we let $K = g_{\\rm{likelihood}}$, the posterior mean can be expressed as:\n",
    "\n",
    "$\\mu_{s_t} = (1-K) D\\mu_{s_{t-1}} + K m_t = D\\mu_{s_{t-1}} + K (m_t - D\\mu_{s_{t-1}})$\n",
    "\n",
    "In classic textbooks, you will often find this expression for the posterior mean; $K$ is known as the Kalman gain and its function is to choose a value partway between the current measurement $m_t$ and the prediction from Step 1."
   ]
  },
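  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "Both forms of the posterior mean describe the same update. Substituting the information weights also gives the textbook closed form of the gain, $K = \\frac{D^2\\sigma_{s_{t-1}}^2 + \\sigma_p^2}{D^2\\sigma_{s_{t-1}}^2 + \\sigma_p^2 + \\sigma_m^2}$: today's prior variance divided by the sum of the prior and measurement variances. Here is a minimal numerical check with hypothetical values:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "prior_mean, prior_var = 9.6, 7.69   # hypothetical prediction from Step 1\n",
    "m_t, meas_var = 15.0, 81.0          # hypothetical measurement and its noise variance\n",
    "\n",
    "# Kalman gain: fraction of the posterior information contributed by the likelihood\n",
    "K = (1 / meas_var) / (1 / prior_var + 1 / meas_var)\n",
    "K_closed_form = prior_var / (prior_var + meas_var)   # textbook form of the same gain\n",
    "\n",
    "weighted_form = (1 - K) * prior_mean + K * m_t          # weighted average\n",
    "innovation_form = prior_mean + K * (m_t - prior_mean)   # prediction + gain * prediction error\n",
    "\n",
    "print(np.isclose(K, K_closed_form), np.isclose(weighted_form, innovation_form))   # True True\n",
    "```\n",
    "\n",
    "A gain close to 0 means the filter mostly trusts its prediction. A gain close to 1 means it mostly follows the new measurement."
   ]
  },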
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# Set random seed\n",
    "np.random.seed(0)\n",
    "\n",
    "# Set parameters\n",
    "T = 50                  # Time duration\n",
    "tau = 25                # dynamics time constant\n",
    "process_noise = 2       # process noise in Astrocat's propulsion unit (standard deviation)\n",
    "measurement_noise = 9   # measurement noise in Astrocat's collar (standard deviation)\n",
    "\n",
    "# Auxiliary variables\n",
    "process_noise_cov = process_noise**2          # process noise in Astrocat's propulsion unit (variance)\n",
    "measurement_noise_cov = measurement_noise**2  # measurement noise in Astrocat's collar (variance)\n",
    "\n",
    "# Initialize arrays\n",
    "t = np.arange(0, T, 1)   # timeline\n",
    "s = np.zeros(T)          # states\n",
    "D = np.exp(-1/tau)       # dynamics multiplier (matrix if s is vector)\n",
    "\n",
    "m = np.zeros(T)          # measurement\n",
    "s_ = np.zeros(T)         # estimate (posterior mean)\n",
    "cov_ = np.zeros(T)       # uncertainty (posterior covariance)\n",
    "\n",
    "# Initial guess of the posterior at time 0\n",
    "initial_guess = gaussian(0, process_noise_cov/(1-D**2))    # In this case, the initial guess (posterior distribution\n",
    "                                                           # at time 0) is the equilibrium distribution, but feel free to\n",
    "                                                           # experiment with other gaussians\n",
    "posterior = initial_guess\n",
    "\n",
    "# Sample initial conditions\n",
    "s[0] = posterior.mean + np.sqrt(posterior.cov) * np.random.randn()   # Sample initial condition from posterior distribution at time 0\n",
    "s_[0] = posterior.mean\n",
    "cov_[0] = posterior.cov\n",
    "\n",
    "# Loop over steps\n",
    "for i in range(1, T):\n",
    "\n",
    "    # Sample true states and corresponding measurements\n",
    "    s[i] = D * s[i-1] + np.random.normal(0, process_noise)    # variable `s` records the true position of Astrocat\n",
    "    m[i] = s[i] + np.random.normal(0, measurement_noise)      # variable `m` records the measurements of Astrocat's collar\n",
    "\n",
    "    ###################################################################\n",
    "    ## Fill out the following then remove\n",
    "    raise NotImplementedError(\"Student exercise: need to implement the Kalman filter\")\n",
    "    ###################################################################\n",
    "\n",
    "    # Step 1. Shift yesterday's posterior to match the deterministic change of the system's dynamics,\n",
    "    #         and broaden it to account for the random change (i.e., add the mean and variance of the process noise).\n",
    "    todays_prior = ...\n",
    "\n",
    "    # Step 2. Now that yesterday's posterior has become today's prior, integrate new evidence\n",
    "    #         (i.e., multiply gaussians from today's prior and likelihood)\n",
    "    likelihood = ...\n",
    "\n",
    "    # Step 2a: To find the posterior variance, add the information (inverse variances) of prior and likelihood\n",
    "    info_prior = 1/todays_prior.cov\n",
    "    info_likelihood = 1/likelihood.cov\n",
    "    info_posterior = ...\n",
    "\n",
    "    # Step 2b: To find the posterior mean, calculate a weighted average of means from prior and likelihood;\n",
    "    #          the weights are just the fraction of information that each gaussian provides!\n",
    "    prior_weight = info_prior / info_posterior\n",
    "    likelihood_weight = info_likelihood / info_posterior\n",
    "    posterior_mean = ...\n",
    "\n",
    "    # Don't forget to convert the posterior information back to a posterior variance!\n",
    "    posterior_cov = 1/info_posterior\n",
    "    posterior = gaussian(posterior_mean, posterior_cov)\n",
    "\n",
    "    s_[i] = posterior.mean\n",
    "    cov_[i] = posterior.cov\n",
    "\n",
    "# Visualize\n",
    "paintMyFilter(D, initial_guess, process_noise_cov, measurement_noise_cov, s, m, s_, cov_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# to_remove solution\n",
    "\n",
    "# Set random seed\n",
    "np.random.seed(0)\n",
    "\n",
    "# Set parameters\n",
    "T = 50                  # Time duration\n",
    "tau = 25                # dynamics time constant\n",
    "process_noise = 2       # process noise in Astrocat's propulsion unit (standard deviation)\n",
    "measurement_noise = 9   # measurement noise in Astrocat's collar (standard deviation)\n",
    "\n",
    "# Auxiliary variables\n",
    "process_noise_cov = process_noise**2          # process noise in Astrocat's propulsion unit (variance)\n",
    "measurement_noise_cov = measurement_noise**2  # measurement noise in Astrocat's collar (variance)\n",
    "\n",
    "# Initialize arrays\n",
    "t = np.arange(0, T, 1)   # timeline\n",
    "s = np.zeros(T)          # states\n",
    "D = np.exp(-1/tau)       # dynamics multiplier (matrix if s is vector)\n",
    "\n",
    "m = np.zeros(T)          # measurement\n",
    "s_ = np.zeros(T)         # estimate (posterior mean)\n",
    "cov_ = np.zeros(T)       # uncertainty (posterior covariance)\n",
    "\n",
    "# Initial guess of the posterior at time 0\n",
    "initial_guess = gaussian(0, process_noise_cov/(1-D**2))    # In this case, the initial guess (posterior distribution\n",
    "                                                           # at time 0) is the equilibrium distribution, but feel free to\n",
    "                                                           # experiment with other gaussians\n",
    "posterior = initial_guess\n",
    "\n",
    "# Sample initial conditions\n",
    "s[0] = posterior.mean + np.sqrt(posterior.cov) * np.random.randn()   # Sample initial condition from posterior distribution at time 0\n",
    "s_[0] = posterior.mean\n",
    "cov_[0] = posterior.cov\n",
    "\n",
    "# Loop over steps\n",
    "for i in range(1, T):\n",
    "\n",
    "    # Sample true states and corresponding measurements\n",
    "    s[i] = D * s[i-1] + np.random.normal(0, process_noise)    # variable `s` records the true position of Astrocat\n",
    "    m[i] = s[i] + np.random.normal(0, measurement_noise)      # variable `m` records the measurements of Astrocat's collar\n",
    "\n",
    "    # Step 1. Shift yesterday's posterior to match the deterministic change of the system's dynamics,\n",
    "    #         and broaden it to account for the random change (i.e., add the mean and variance of the process noise).\n",
    "    todays_prior = gaussian(D * posterior.mean, D**2 * posterior.cov + process_noise_cov)\n",
    "\n",
    "    # Step 2. Now that yesterday's posterior has become today's prior, integrate new evidence\n",
    "    #         (i.e., multiply gaussians from today's prior and likelihood)\n",
    "    likelihood = gaussian(m[i], measurement_noise_cov)\n",
    "\n",
    "    # Step 2a: To find the posterior variance, add the information (inverse variances) of prior and likelihood\n",
    "    info_prior = 1/todays_prior.cov\n",
    "    info_likelihood = 1/likelihood.cov\n",
    "    info_posterior = info_prior + info_likelihood\n",
    "\n",
    "    # Step 2b: To find the posterior mean, calculate a weighted average of means from prior and likelihood;\n",
    "    #          the weights are just the fraction of information that each gaussian provides!\n",
    "    prior_weight = info_prior / info_posterior\n",
    "    likelihood_weight = info_likelihood / info_posterior\n",
    "    posterior_mean = prior_weight * todays_prior.mean + likelihood_weight * likelihood.mean\n",
    "\n",
    "    # Don't forget to convert the posterior information back to a posterior variance!\n",
    "    posterior_cov = 1/info_posterior\n",
    "    posterior = gaussian(posterior_mean, posterior_cov)\n",
    "\n",
    "    s_[i] = posterior.mean\n",
    "    cov_[i] = posterior.cov\n",
    "\n",
    "# Visualize\n",
    "paintMyFilter(D, initial_guess, process_noise_cov, measurement_noise_cov, s, m, s_, cov_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Section 2.2: Estimation accuracy\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Interactive Demo 2.2: Compare states, estimates, and measurements\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "How well do the estimates $\\hat{s}$ match the actual values $s$? How does the distribution of errors $s_t - \\hat{s}_t$ compare to the posterior variance? Why? Try different parameters of the Hidden Markov Model and observe how the properties change.\n",
    "\n",
    "How do the _measurements_ $m$ compare to the true states?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @markdown Execute cell to enable the demo\n",
    "display(HTML('''<style>.widget-label { min-width: 15ex !important; }</style>'''))\n",
    "\n",
    "@widgets.interact(tau=widgets.FloatSlider(tau_max/2, description='tau', min=tau_min, max=tau_max),\n",
    "                  process_noise=widgets.FloatSlider(2, description=\"process noise\", min=process_noise_min, max=process_noise_max),\n",
    "                  measurement_noise=widgets.FloatSlider(3, description=\"observation noise\", min=measurement_noise_min, max=measurement_noise_max),\n",
    "                  flag_m = widgets.Checkbox(value=False, description='measurements', disabled=False, indent=False))\n",
    "\n",
    "def stochastic_system(tau, process_noise, measurement_noise, flag_m):\n",
    "    T = T_max\n",
    "    t = np.arange(0, T_max, 1)              # timeline\n",
    "    s = np.zeros(T_max)                     # states\n",
    "    D = np.exp(-1/tau)                      # dynamics multiplier (matrix if s is vector)\n",
    "\n",
    "    process_noise_cov = process_noise**2          # process noise in Astrocat's propulsion unit (variance)\n",
    "    measurement_noise_cov = measurement_noise**2  # measurement noise in Astrocat's collar (variance)\n",
    "\n",
    "    prior_mean = 0\n",
    "    prior_cov = process_noise_cov/(1-D**2)\n",
    "\n",
    "\n",
    "    s[0] = np.sqrt(prior_cov) * np.random.randn()   # Sample initial condition from equilibrium distribution\n",
    "\n",
    "    m = np.zeros(T_max)    # measurement\n",
    "    s_ = np.zeros(T_max)    # estimate (posterior mean)\n",
    "    cov_ = np.zeros(T_max)    # uncertainty (posterior covariance)\n",
    "\n",
    "    s_[0] = prior_mean\n",
    "    cov_[0] = prior_cov\n",
    "    posterior = gaussian(prior_mean, prior_cov)\n",
    "\n",
    "    for i in range(1, T):\n",
    "        s[i] = D * s[i-1] + process_noise * np.random.randn()\n",
    "        m[i] = s[i] + measurement_noise * np.random.randn()\n",
    "\n",
    "        prior, likelihood, posterior = filter(D, process_noise_cov, measurement_noise_cov, posterior, m[i])\n",
    "\n",
    "        s_[i] =  posterior.mean\n",
    "        cov_[i] = posterior.cov\n",
    "\n",
    "    fig = plt.figure(figsize=[10, 5])\n",
    "    ax = plt.subplot(1, 2, 1)\n",
    "    ax.set_xlabel('s')\n",
    "    ax.set_ylabel(r'$\\mu$')\n",
    "\n",
    "    sbounds = 1.1*max(max(np.abs(s)), max(np.abs(s_)), max(np.abs(m)))\n",
    "    ax.plot([-sbounds, sbounds], [-sbounds, sbounds], 'k')    # plot line of equality\n",
    "    ax.errorbar(s, s_, yerr=2*np.sqrt(cov_[-1]), marker='.', mfc='black', mec='black', linestyle='none', color='gray')\n",
    "\n",
    "    axhist = plt.subplot(1, 2, 2)\n",
    "    axhist.set_xlabel(r'error $s-\\hat{s}$')\n",
    "    axhist.set_ylabel('probability')\n",
    "    axhist.hist(s-s_, density=True, bins=25, alpha=.5, label='histogram of estimate errors', color='yellow')\n",
    "\n",
    "    if flag_m:\n",
    "        ax.plot(s, m, marker='.', linestyle='none', color='red')\n",
    "        axhist.hist(s-m, density=True, bins=25, alpha=.5, label='histogram of measurement errors', color='orange')\n",
    "\n",
    "    domain = np.arange(-sbounds, sbounds, 0.1)\n",
    "    pdf_g = norm.pdf(domain, 0, np.sqrt(cov_[-1]))\n",
    "    axhist.fill_between(domain, pdf_g, color='black', alpha=0.5, label=r'posterior shifted to mean')\n",
    "    axhist.legend()\n",
    "\n",
    "plt.show()"
   ]
  },
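  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "To put numbers on the questions above, here is a minimal sketch that simulates a long run with the exercise parameters. It compares the empirical variance of the estimation errors with the posterior variance and with the variance of the raw measurement errors. The seed, run length, and burn-in are hypothetical choices. The loop re-implements the recipe from Coding Exercise 2.1 with plain floats instead of the notebook's `gaussian` helper:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(1)      # hypothetical seed\n",
    "T, tau = 5000, 25                   # long run; tau as in the exercise\n",
    "q, r = 2.0**2, 9.0**2               # process and measurement noise variances (exercise values)\n",
    "D = np.exp(-1 / tau)\n",
    "\n",
    "s = np.zeros(T)       # true states\n",
    "m = np.zeros(T)       # measurements\n",
    "s_hat = np.zeros(T)   # posterior means\n",
    "var = np.zeros(T)     # posterior variances\n",
    "var[0] = q / (1 - D**2)                  # start from the equilibrium distribution\n",
    "s[0] = rng.normal(0, np.sqrt(var[0]))\n",
    "\n",
    "for i in range(1, T):\n",
    "    s[i] = D * s[i - 1] + rng.normal(0, np.sqrt(q))\n",
    "    m[i] = s[i] + rng.normal(0, np.sqrt(r))\n",
    "    prior_mean = D * s_hat[i - 1]                           # Step 1: shift ...\n",
    "    prior_var = D**2 * var[i - 1] + q                       # ... and broaden\n",
    "    info = 1 / prior_var + 1 / r                            # Step 2a: add information\n",
    "    s_hat[i] = (prior_mean / prior_var + m[i] / r) / info   # Step 2b: information-weighted mean\n",
    "    var[i] = 1 / info\n",
    "\n",
    "burn = 100   # skip the initial transient while the posterior variance settles\n",
    "err_var = np.var(s[burn:] - s_hat[burn:])\n",
    "meas_err_var = np.var(s[burn:] - m[burn:])\n",
    "print('filter error variance      :', err_var)\n",
    "print('final posterior variance   :', var[-1])\n",
    "print('measurement error variance :', meas_err_var)\n",
    "```\n",
    "\n",
    "When the filter uses the true noise levels, the first two numbers should roughly agree (about 14 here), and both should be well below the measurement error variance of about 81. If the filter's assumed noise levels differ from the true ones, this agreement generally breaks down."
   ]
  },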
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Section 2.3: Searching for Astrocat\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "### Interactive Demo 2.3: How long does it take to find Astrocat?\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "\n",
    "Here we plot the posterior variance as a function of time. Before Mission Control gets any measurements, their only information about Astrocat's location is the prior. After some measurements, they home in on Astrocat.\n",
    "* How does the variance shrink with time?\n",
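    "* The speed depends on the process dynamics, but does it also depend on the signal-to-noise ratio (SNR)? (Here SNR is the ratio of the process-noise variance to the measurement-noise variance, measured in decibels: a log scale where 1 dB corresponds to 0.1 in $\\log_{10}$ units, so $-20$ dB means $\\mathrm{SNR} = 10^{-2}$.)\n",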
    "\n",
    "The red curve shows how rapidly the latent variance equilibrates exponentially from an initial condition, with a time constant of $\\sim 1/(1-D^2)$. (**Note:** We adjusted the curve by shifting and scaling so it lines up visually with the posterior equilibrium variance. This makes it easier to compare timescales.) Does the latent process converge faster or slower than the posterior? Can you explain this based on how the Kalman filter integrates evidence?"
   ]
  },
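  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "One property is worth noticing. In the recipe, the posterior variance is updated using only $D$ and the noise variances, never the measured values $m_t$. So the whole uncertainty trajectory, including its equilibrium, can be computed before any data arrive. The minimal sketch below iterates this variance recursion and compares its end point with the closed-form equilibrium used in the demo code. The slider settings are hypothetical. As in the demo, the process-noise variance is 1 and the measurement-noise variance is $1/\\mathrm{SNR}$:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "tau, SNRdB, T = 25.0, -20.0, 200   # hypothetical settings within the demo's slider ranges\n",
    "D = np.exp(-1 / tau)\n",
    "q = 1.0                        # process noise variance (as in the demo)\n",
    "SNR = 10 ** (0.1 * SNRdB)      # decibels -> ratio: -20 dB gives SNR = 0.01\n",
    "r = q / SNR                    # measurement noise variance\n",
    "\n",
    "# Posterior variance recursion: it never uses the measurement values themselves\n",
    "post_var = np.zeros(T)\n",
    "post_var[0] = q / (1 - D**2)   # start from the equilibrium (prior) variance\n",
    "for i in range(1, T):\n",
    "    prior_var = D**2 * post_var[i - 1] + q\n",
    "    post_var[i] = 1 / (1 / prior_var + 1 / r)\n",
    "\n",
    "# Closed-form fixed point of the recursion (same expression as in the demo code)\n",
    "eq_post = q * (D**2 - 1 - SNR + np.sqrt((D**2 - 1 - SNR)**2 + 4 * D**2 * SNR)) / (2 * D**2 * SNR)\n",
    "print(post_var[-1], eq_post)   # the recursion settles at the closed-form value\n",
    "```\n",
    "\n",
    "Changing `SNRdB` or `tau` in this sketch changes both the equilibrium value and how many steps the recursion needs to reach it."
   ]
  },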
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "form",
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# @markdown Execute this cell to enable the demo\n",
    "\n",
    "display(HTML('''<style>.widget-label { min-width: 15ex !important; }</style>'''))\n",
    "\n",
    "@widgets.interact(T=widgets.IntSlider(tau_max, description=\"max time\", min=2, max=T_max-1),\n",
    "                  tau=widgets.FloatSlider(tau_max/2, description='time constant', min=tau_min, max=tau_max),\n",
    "                  SNRdB=widgets.FloatSlider(-20., description=\"SNR (decibels)\", min=-40., max=10.))\n",
    "\n",
    "def stochastic_system(T, tau, SNRdB):\n",
    "\n",
    "    t = np.arange(0, T, 1)              # timeline\n",
    "    s = np.zeros(T)                     # states\n",
    "    D = np.exp(-1/tau)                  # dynamics matrix (scalar here)\n",
    "    prior_mean = 0\n",
    "    process_noise = 1\n",
    "    SNR = 10**(.1*SNRdB)\n",
    "    measurement_noise = process_noise / SNR\n",
    "    prior_cov = process_noise/(1-D**2)\n",
    "\n",
    "    s[0] = np.sqrt(prior_cov) * unit_process_noise[0]   # Sample initial condition from equilibrium distribution\n",
    "\n",
    "    m = np.zeros(T)    # measurements\n",
    "    s_ = np.zeros(T)    # estimates (posterior mean)\n",
    "    cov_ = np.zeros(T)    # uncertainty (posterior covariance)\n",
    "    pcov = np.zeros(T)    # process covariance\n",
    "\n",
    "    s_[0] = prior_mean\n",
    "    cov_[0] = prior_cov\n",
    "    posterior = gaussian(prior_mean, prior_cov)\n",
    "\n",
    "    for i in range(1, T):\n",
    "        s[i] = D * s[i-1] + np.sqrt(process_noise) * unit_process_noise[i]\n",
    "        m[i] = s[i] + np.sqrt(measurement_noise) * unit_measurement_noise[i]\n",
    "\n",
    "        prior, likelihood, posterior = filter(D, process_noise, measurement_noise, posterior, m[i])\n",
    "\n",
    "        s_[i] =  posterior.mean\n",
    "        cov_[i] = posterior.cov\n",
    "        pcov[i] = D**2 * pcov[i-1] + process_noise\n",
    "\n",
    "    equilibrium_posterior_var = process_noise * (D**2 - 1 - SNR + np.sqrt((D**2 - 1 - SNR)**2 + 4 * D**2 * SNR)) / (2 * D**2 * SNR)\n",
    "\n",
    "    equilibrium_process_var = process_noise / (1-D**2)\n",
    "\n",
    "    scale = (max(cov_) - equilibrium_posterior_var) / equilibrium_process_var\n",
    "    pcov = pcov * scale   # scale for better visual comparison of temporal structure\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.set_xlabel('time')\n",
    "    ax.set_xlim([0, T])\n",
    "\n",
    "    ax.fill_between(t, 0, cov_, color='black', alpha=0.3)\n",
    "    ax.plot(t, cov_, color='black', label='posterior variance')\n",
    "    ax.set_ylabel('posterior variance')\n",
    "    ax.set_ylim([0, max(cov_)])\n",
    "\n",
    "    ax2 = ax.twinx()  # instantiate a second axes that shares the same x-axis\n",
    "    ax2.fill_between(t, min(pcov), pcov, color='red', alpha=0.3)\n",
    "    ax2.plot(t, pcov, color='red', label='hidden process variance')\n",
    "    ax2.set_ylabel('hidden process variance (scaled)', color='red', rotation=-90, labelpad=20)\n",
    "\n",
    "    ax2.tick_params(axis='y', labelcolor='red')\n",
    "    ax2.set_yticks([0, equilibrium_process_var - equilibrium_posterior_var])\n",
    "    ax2.set_yticklabels(['0', 'equilibrium\\nprocess var'])\n",
    "    ax2.set_ylim([max(cov_), 0])\n",
    "\n",
    "    fig.tight_layout()  # otherwise the right y-label is slightly clipped\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "\n",
    "**Applications of Kalman filter in brain science**\n",
    "\n",
    "* Brain-Computer Interface: estimate intended movements using neural activity as measurements.\n",
    "* Data analysis: estimate brain activity from noisy measurements (e.g. EEG)\n",
    "* Model of perception: prey tracking using noisy sensory measurements\n",
    "* Imagine your own! When are you trying to estimate something you cannot see directly?\n",
    "\n",
    "Many variants relax the Kalman filter's assumptions of Gaussian states and measurements and linear dynamics. For example, the extended and unscented Kalman filters handle nonlinear dynamics, and particle filters handle non-Gaussian distributions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "---\n",
    "# Summary\n",
    "\n",
    "In this tutorial, you:\n",
    "- simulated a 1D continuous linear dynamical system and took noisy measurements of the hidden state\n",
    "- used a Kalman filter to recover the hidden states more accurately than the noisy measurements alone, and connected this to Bayesian ideas\n",
    "- played around with parameters of the process to better understand Kalman filter behavior\n"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "W8_Tutorial3",
   "provenance": [],
   "toc_visible": true
  },
  "interpreter": {
   "hash": "9516f62da91337f10c2adbe814d9c63a4b08f8271333386358218606edb781e3"
  },
  "kernel": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "kernelspec": {
   "display_name": "Python 3.7.11 64-bit ('pw3': conda)",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
